{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "blR8hhA_e-4S"
      },
      "source": [
        "# LongLLaMA Code: Focused Transformer Training for Context Scaling\n",
        "**LongLLaMA-Code-Instruct-7B is a LLM that builds upon the foundation of CodeLLaMA using FoT context extension and instruction tuning**.\n",
        "\n",
        "It is created by first taking the [CodeLLaMA-7B](https://huggingface.co/codellama/CodeLlama-7b-hf), continuing pre-training with [Focused Transformer (FoT)](https://arxiv.org/abs/2307.03170) method, and then instruction tuning the model.\n",
        "\n",
        "This notebook is a demo of [LongLLaMA-Code-Instruct-7B](TODO).  \n",
        "To create the model we have used the following datasets:\n",
        "* [MathInstruct](https://huggingface.co/datasets/TIGER-Lab/MathInstruct)\n",
        "* [OpenOrca](https://huggingface.co/datasets/Open-Orca/OpenOrca)\n",
        "* [ShareGPT-Processed](https://huggingface.co/datasets/zetavg/ShareGPT-Processed)\n",
        "\n",
        "For more, see the [FoT paper](https://arxiv.org/abs/2307.03170) and the [GitHub repository](https://github.com/CStanKonrad/long_llama).  \n",
        "Note that as we have started from the Meta's [CodeLLaMA-7B](https://huggingface.co/codellama/CodeLlama-7b-hf) model and used GPT outputs to finetune it, the model is aimed for research purposes only. \n",
        " \n",
        "\n",
        "**We provide the basic code for quantization that should suffice to run most of the demo parts on the free Colab GPU.**  \n",
        "**However, the quantization code is not optimized and may result in reduced performance.**  \n",
        "**Generation with free colab GPU may be slow, but processing the long input should take 1-2 min**  \n",
        "**Running this model without quantization may require an A100 40GB GPU.**  "
      ]
    },
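    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The memory notes above can be made more concrete with a back-of-the-envelope estimate. The cell below is a minimal sketch, not a measurement: it assumes a round figure of 7e9 parameters (taken from the model name) and counts only the weights, ignoring activations and the FoT memory cache, so real usage will be higher. `APPROX_NUM_PARAMS`, `BYTES_PER_PARAM`, and `weight_memory_gib` are illustrative names introduced here."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Rough estimate of the memory taken by the model weights alone.\n",
        "# ASSUMPTION: ~7e9 parameters, a round figure taken from the model name.\n",
        "APPROX_NUM_PARAMS = 7e9\n",
        "\n",
        "# Bytes needed to store one parameter at each precision.\n",
        "BYTES_PER_PARAM = {\"float32\": 4.0, \"bfloat16\": 2.0, \"8-bit\": 1.0, \"4-bit\": 0.5}\n",
        "\n",
        "\n",
        "def weight_memory_gib(num_params: float, bytes_per_param: float) -> float:\n",
        "    \"\"\"Return the size of the weights in GiB (2**30 bytes).\"\"\"\n",
        "    return num_params * bytes_per_param / 2**30\n",
        "\n",
        "\n",
        "for precision, num_bytes in BYTES_PER_PARAM.items():\n",
        "    size = weight_memory_gib(APPROX_NUM_PARAMS, num_bytes)\n",
        "    print(f\"{precision:>8}: ~{size:.1f} GiB\")"
      ]
    },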
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZTROtgpafB_8"
      },
      "source": [
        "# Initial steps"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "La3cPHfCe7hU"
      },
      "outputs": [],
      "source": [
        "!pip install --upgrade pip\n",
        "!pip install transformers==4.33.2 sentencepiece accelerate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0x_utNYxfECA"
      },
      "outputs": [],
      "source": [
        "import gc\n",
        "import numpy as np\n",
        "import torch\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, PreTrainedModel, PreTrainedTokenizer\n",
        "from typing import List, Optional"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LrrFlXKMfHMs"
      },
      "outputs": [],
      "source": [
        "MODEL_PATH = \"syzymon/long_llama_code_7b_instruct\"\n",
        "TOKENIZER_PATH = MODEL_PATH\n",
        "# to reduce GPU memory usage we will use reduced precision\n",
        "TORCH_DTYPE = torch.bfloat16\n",
        "\n",
        "if torch.cuda.is_available():\n",
        "    device = torch.device(\"cuda\")\n",
        "else:\n",
        "    device = torch.device(\"cpu\")\n",
        "\n",
        "print(device)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# To fit most of the demo parts on a single Google Colab GPU we\n",
        "# provide a basic unoptimized quantization code\n",
        "# change to False to disable the quantization\n",
        "QUANTIZED = True"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5fE5Z1ABfJUp"
      },
      "outputs": [],
      "source": [
        "tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)\n",
        "\n",
        "# unoptimized quantization code for running with free Colab GPU\n",
        "def load_and_qunatize_model(num_bit: int, model_path):\n",
        "    print(f\"!!!!!WARNING!!!!! The mode will be quantized to {num_bit} bits!\\n\"\n",
        "          \"This may affect the model performance!\")\n",
        "    \n",
        "    !pip3 install huggingface_hub\n",
        "    !pip3 install bitsandbytes\n",
        "    !git clone https://github.com/CStanKonrad/long_llama.git\n",
        "    !cp -r long_llama/src long_llama_code/\n",
        "    from long_llama_code.modeling_longllama import LongLlamaForCausalLM\n",
        "    from long_llama_code.configuration_longllama import LongLlamaConfig\n",
        "    from transformers import AutoConfig\n",
        "    from accelerate.utils import BnbQuantizationConfig\n",
        "    from accelerate.utils import load_and_quantize_model\n",
        "    from accelerate import init_empty_weights\n",
        "    from huggingface_hub import snapshot_download, hf_hub_download\n",
        "\n",
        "    \n",
        "    cfg = LongLlamaConfig.from_pretrained(model_path)\n",
        "    cfg.mem_attention_grouping = (1, 1024)\n",
        "    with init_empty_weights():\n",
        "        empty_model = LongLlamaForCausalLM(cfg)\n",
        "\n",
        "    gc.collect()\n",
        "    if num_bit == 8:\n",
        "        weights_loc = hf_hub_download(repo_id=MODEL_PATH, filename=\"quantized/pytorch_model_8bit.bin\")\n",
        "        bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, llm_int8_threshold = 6)\n",
        "    elif num_bit == 4:\n",
        "        # May give out of RAM on Colab\n",
        "        weights_loc = snapshot_download(MODEL_PATH) #MODEL_PATH\n",
        "        bnb_quantization_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)\n",
        "    else:\n",
        "        raise ValueError(f\"{num_bit} quantization not supported.\")\n",
        "\n",
        "    gc.collect()\n",
        "    model = load_and_quantize_model(empty_model, weights_location=weights_loc, bnb_quantization_config=bnb_quantization_config, device_map=\"auto\")\n",
        "    model.eval()\n",
        "    return model\n",
        "\n",
        "if not QUANTIZED:\n",
        "    model = AutoModelForCausalLM.from_pretrained(\n",
        "        MODEL_PATH,\n",
        "        torch_dtype=TORCH_DTYPE,\n",
        "        device_map=device,\n",
        "        trust_remote_code=True,\n",
        "        # mem_attention_grouping is used\n",
        "        # to trade speed for memory usage\n",
        "        # for details, see the section Additional configuration\n",
        "        # in the Github repository\n",
        "        mem_attention_grouping=(1, 1024),\n",
        "    )\n",
        "    model.eval()\n",
        "else:\n",
        "    model = load_and_qunatize_model(8, MODEL_PATH)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ElblcG7MfOM4"
      },
      "source": [
        "# The demo"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "chPK2r8CfRcC"
      },
      "source": [
        "## Question answering on long documents\n",
        "Here we show the ability of the model to answer questions about long documents.   \n",
        "As it is a 7B parameter model it should be better than the [3B parameter one](https://colab.research.google.com/github/CStanKonrad/long_llama/blob/main/long_llama_instruct_colab.ipynb).\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Downloading and memory loading\n",
        "Code for downloading files and loading them into model memory.  "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rt0PMfKDfLjy"
      },
      "outputs": [],
      "source": [
        "import urllib.request\n",
        "import tempfile\n",
        "import shutil\n",
        "import os\n",
        "\n",
        "\n",
        "def get_file(url: str, main_file: Optional[str] = None):\n",
        "    with tempfile.TemporaryDirectory() as tmp_dir:\n",
        "        file_path = os.path.join(tmp_dir, \"_paper.tmp\")\n",
        "        if main_file is not None:\n",
        "            # we are dealing with an archive\n",
        "            file_path += \".tar\"\n",
        "\n",
        "        urllib.request.urlretrieve(url, file_path)\n",
        "\n",
        "        if main_file is not None:\n",
        "            # we are dealing with an archive\n",
        "            shutil.unpack_archive(file_path, tmp_dir)\n",
        "            main_file_path = os.path.join(tmp_dir, main_file)\n",
        "        else:\n",
        "            main_file_path = file_path\n",
        "\n",
        "        with open(main_file_path, \"r\") as f:\n",
        "            data = f.read()\n",
        "\n",
        "    return data\n",
        "\n",
        "\n",
        "def get_paper(url: str, main_file: Optional[str] = None):\n",
        "    return get_file(url, main_file)\n",
        "\n",
        "\n",
        "def get_files(url_list: List[str]):\n",
        "    data = []\n",
        "    for url in url_list:\n",
        "        data.append(get_file(url))\n",
        "\n",
        "    data = \"\\n\".join(data)\n",
        "    return data\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def load_to_memory(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, text: str):\n",
        "    tokenized_data = tokenizer(text, return_tensors=\"pt\")\n",
        "    input_ids = tokenized_data.input_ids\n",
        "    input_ids = input_ids.to(model.device)\n",
        "    torch.manual_seed(0)\n",
        "    output = model(input_ids=input_ids)\n",
        "    memory = output.past_key_values\n",
        "    return memory\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def generate_with_memory(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, memory, prompt: str, temperature=0.2):\n",
        "    tokenized_data = tokenizer(prompt, return_tensors=\"pt\")\n",
        "    input_ids = tokenized_data.input_ids\n",
        "    input_ids = input_ids.to(model.device)\n",
        "\n",
        "    streamer = TextStreamer(tokenizer, skip_prompt=False)\n",
        "\n",
        "    new_memory = memory\n",
        "\n",
        "    stop = False\n",
        "    while not stop:\n",
        "        output = model(input_ids, past_key_values=new_memory, last_context_length=3072)\n",
        "        new_memory = output.past_key_values\n",
        "        assert len(output.logits.shape) == 3\n",
        "        assert output.logits.shape[0] == 1\n",
        "        last_logit = output.logits[[0], [-1], :]\n",
        "        dist = torch.distributions.Categorical(logits=last_logit / temperature)\n",
        "        next_token = dist.sample()\n",
        "        if next_token[0] == tokenizer.eos_token_id:\n",
        "            streamer.put(next_token[None, :])\n",
        "            streamer.end()\n",
        "            stop = True\n",
        "        else:\n",
        "            input_ids = next_token[None, :]\n",
        "            streamer.put(input_ids)\n",
        "\n",
        "\n",
        "PROMPT_PREFIX = \"You are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can.\\n\\n\"\n",
        "\n",
        "\n",
        "def construct_question_prompt(question: str):\n",
        "    prompt = (\n",
        "        f\"\\nYou are an AI assistant. User will you give you a task. Your goal is to complete the task as faithfully as you can.\\n\\n\"\n",
        "        \"Answer he question below using the information from the text above.\\n\"\n",
        "        f\"Question: {question}\\nAnswer: \"\n",
        "    )\n",
        "    return prompt\n",
        "\n",
        "\n",
        "def ask_model(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, prompt: str, memory, seed=0):\n",
        "    tokenized_data = tokenizer(prompt, return_tensors=\"pt\")\n",
        "    input_ids = tokenized_data.input_ids\n",
        "    input_ids = input_ids.to(model.device)\n",
        "\n",
        "    torch.manual_seed(seed)\n",
        "    generate_with_memory(model, tokenizer, memory, prompt)"
      ]
    },
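    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`generate_with_memory` above samples the next token from `Categorical(logits=last_logit / temperature)`, which amounts to a softmax over temperature-scaled logits. The cell below is a small, self-contained illustration of that scaling on made-up logits; `softmax_with_temperature` and `toy_logits` are hypothetical names used only here. Lower temperatures concentrate probability mass on the highest-scoring token, so the default of `0.2` tends to give fairly deterministic answers."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import math\n",
        "from typing import List\n",
        "\n",
        "\n",
        "def softmax_with_temperature(logits: List[float], temperature: float) -> List[float]:\n",
        "    \"\"\"Softmax of logits / temperature, computed in a numerically stable way.\"\"\"\n",
        "    scaled = [x / temperature for x in logits]\n",
        "    max_scaled = max(scaled)  # subtract the maximum to avoid overflow in exp\n",
        "    exps = [math.exp(x - max_scaled) for x in scaled]\n",
        "    total = sum(exps)\n",
        "    return [e / total for e in exps]\n",
        "\n",
        "\n",
        "toy_logits = [2.0, 1.0, 0.0]  # made-up scores for three candidate tokens\n",
        "for temperature in (1.0, 0.2):\n",
        "    probs = softmax_with_temperature(toy_logits, temperature)\n",
        "    print(temperature, [round(p, 3) for p in probs])"
      ]
    },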
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "try:\n",
        "    del chatbot\n",
        "except:\n",
        "    pass\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Questions about code\n",
        "We download the instruction tuning files from the long_llama repository and ask the model questions about the implementation.\n",
        "Each question is asked independently without updating the memory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "instruct_dp = get_files(\n",
        "    [\n",
        "        \"https://raw.githubusercontent.com/CStanKonrad/long_llama/2c88620d0ec9c28e13b4c208be34ebac68b90e37/instruction_fine_tuning/arguments.py\",\n",
        "        \"https://raw.githubusercontent.com/CStanKonrad/long_llama/2c88620d0ec9c28e13b4c208be34ebac68b90e37/instruction_fine_tuning/data_processing.py\",\n",
        "        \"https://raw.githubusercontent.com/CStanKonrad/long_llama/2c88620d0ec9c28e13b4c208be34ebac68b90e37/instruction_fine_tuning/fine_tuning.py\",\n",
        "    ]\n",
        ")\n",
        "try:\n",
        "    del fot_memory\n",
        "except:\n",
        "    pass\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()\n",
        "fot_memory = load_to_memory(model, tokenizer, PROMPT_PREFIX + instruct_dp)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"What is the purpose of this code?\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"What is used for preparing the data? Name the most important functions and classes.\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"Can you say something more about `tokenize_text_no_special_tokens`?\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"Can you say something more about `MixedTuneDataset`?\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"What are the main model configuration options? Enumerate them, use `` for their names.\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"What are the options to configure the data? Enumerate them, use `` for their names.\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Questions about FoT paper\n",
        "We download the FoT paper and ask basic questions about it's content.  "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "14ao1rk6faHM"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "    del fot_memory\n",
        "except:\n",
        "    pass\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()\n",
        "fot_paper = get_paper(url=\"https://raw.githubusercontent.com/CStanKonrad/long_llama/main/assets/fot_paper.tar\", main_file=\"fot_paper.tex\")\n",
        "if QUANTIZED:\n",
        "    fot_paper = fot_paper[:50000]\n",
        "fot_memory = load_to_memory(model, tokenizer, PROMPT_PREFIX + fot_paper)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uz8Gja1Nfdkb"
      },
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"What is the paper above about?\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3tQfPSURfejw"
      },
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"What method is introduced in the paper?\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s0xg6aBAfgDG"
      },
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"How is the 3B model called by the authors?\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sd3wPEzAfhFF"
      },
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"Name all six authors of the presented paper.\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6-mDT8TufjCl"
      },
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt(\"What is the distraction issue?\")\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "prompt = construct_question_prompt('What are the three main contributions of the paper?')\n",
        "ask_model(model, tokenizer, prompt, fot_memory)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Working with code\n",
        "\n",
        "The base [CodeLLaMA-7B](https://huggingface.co/codellama/CodeLlama-7b-hf) model was trained on a large amount of code data.    \n",
        "During the FoT tuning Python constituted a significant portion of the mixture.  \n",
        "For instruction tuning, we have utilized the [MathInstruct](https://huggingface.co/datasets/TIGER-Lab/MathInstruct) dataset that contains both Chain of Thought and Python of Thought examples.  \n",
        "Because of all of the above, the model can be used for manipulation of the coded data that includes but is not limited to\n",
        "* refactoring\n",
        "* rewriting to other languages\n",
        "* explaining  \n",
        "\n",
        "However, as this model has only 7B parameters it can still make simple mistakes.  \n",
        "Below we show some examples of how the model can be used."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Code for the chat interface\n",
        "Here, we provide the code for communicating with the model in an iterative way."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "class ChatOutputBuffer:\n",
        "    \"\"\"\n",
        "    For providing online output that\n",
        "    is truncated after generating specified (stop_text)\n",
        "    sequence of characters\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, stop_text: List[str], tokenizer: PreTrainedModel):\n",
        "        self.tokenizer = tokenizer\n",
        "        self.streamer = TextStreamer(tokenizer, skip_prompt=False)\n",
        "        self.max_stop_seq = 0\n",
        "        self.stop_seq = []\n",
        "        for st in stop_text:\n",
        "            self.stop_seq.append(st)\n",
        "            self.max_stop_seq = max(self.max_stop_seq, len(st))\n",
        "\n",
        "        self.output_buffer = np.empty((0,), dtype=np.int64)\n",
        "\n",
        "    def reset_output_buffer(self):\n",
        "        self.output_buffer = np.empty((0,), dtype=np.int64)\n",
        "\n",
        "    def advance_output(self):\n",
        "        beg = 0\n",
        "        end = len(self.output_buffer) - self.max_stop_seq\n",
        "\n",
        "        if end > beg:\n",
        "            output = self.output_buffer[beg:end]\n",
        "            self.streamer.put(output)\n",
        "            self.output_buffer = self.output_buffer[end:]\n",
        "\n",
        "    def flush_buffer(self):\n",
        "        if len(self.output_buffer) > 0:\n",
        "            self.streamer.put(self.output_buffer)\n",
        "            self.output_buffer = self.output_buffer[len(self.output_buffer) :]\n",
        "        self.streamer.end()\n",
        "\n",
        "    def generation_too_long(self, text: str) -> int:\n",
        "        end_requests = 0\n",
        "        for st in self.stop_seq:\n",
        "            if text.endswith(st):\n",
        "                end_requests += 1\n",
        "        return end_requests\n",
        "\n",
        "    def update_buffer(self, next_tok: int) -> bool:\n",
        "        assert isinstance(next_tok, int)\n",
        "\n",
        "        array_next_tok = np.array([next_tok], dtype=np.int64)\n",
        "        self.output_buffer = np.concatenate([self.output_buffer, array_next_tok], axis=0)\n",
        "\n",
        "        suffix = self.output_buffer[-self.max_stop_seq :]\n",
        "        decoded = self.tokenizer.decode(suffix)\n",
        "        end_requests = self.generation_too_long(decoded)\n",
        "        if end_requests > 0:\n",
        "            decoded = self.tokenizer.decode(suffix[1:])\n",
        "            while self.generation_too_long(decoded) == end_requests:\n",
        "                suffix = suffix[1:]\n",
        "                decoded = self.tokenizer.decode(suffix[1:])\n",
        "\n",
        "            left_intact = len(self.output_buffer) - len(suffix)\n",
        "\n",
        "            self.output_buffer = self.output_buffer[:left_intact]\n",
        "            self.flush_buffer()\n",
        "            return True\n",
        "\n",
        "        self.advance_output()\n",
        "        return False\n",
        "\n",
        "\n",
        "class SimpleChatBot:\n",
        "    def __init__(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer):\n",
        "        self.model = model\n",
        "        self.tokenizer = tokenizer\n",
        "        self.prompt = \"A chat between a user (denoted as USER:) and an artificial intelligence assistant (denoted as ASSISTANT:). The assistant gives helpful, detailed, and polite answers to the user's questions.\\n\\n\"\n",
        "        self.tokenized_prompt = self.tokenizer.encode(self.prompt, return_tensors=\"pt\", add_special_tokens=False)\n",
        "        self.tokenized_prompt = torch.concatenate(\n",
        "            [torch.tensor([self.tokenizer.bos_token_id], dtype=torch.long).reshape(1, 1), self.tokenized_prompt],\n",
        "            dim=-1,\n",
        "        )\n",
        "        self.model_name = \"\\nASSISTANT: \"\n",
        "        self.tokenized_model_name = self.tokenizer.encode(\n",
        "            self.model_name, return_tensors=\"pt\", add_special_tokens=False\n",
        "        )\n",
        "        self.user_name = \"\\nUSER: \"\n",
        "        self.tokenized_user_name = self.tokenizer.encode(self.user_name, return_tensors=\"pt\", add_special_tokens=False)\n",
        "        self.past_key_values = None\n",
        "\n",
        "        self.t = 0.2\n",
        "        self.output_buffer = ChatOutputBuffer(\n",
        "            [self.model_name.strip(), self.user_name.strip(), self.tokenizer.eos_token], self.tokenizer\n",
        "        )\n",
        "\n",
        "    @torch.no_grad()\n",
        "    def ask(self, text: str):\n",
        "        input_ids = self.tokenizer.encode(text, return_tensors=\"pt\", add_special_tokens=False)\n",
        "\n",
        "        input_ids = torch.concatenate([self.tokenized_user_name, input_ids, self.tokenized_model_name], dim=-1)\n",
        "\n",
        "        if self.past_key_values is None:\n",
        "            input_ids = torch.concatenate([self.tokenized_prompt, input_ids], dim=-1)\n",
        "\n",
        "        self.output_buffer.reset_output_buffer()\n",
        "        output_text = self.model_name\n",
        "        output_ids = self.tokenizer.encode(\n",
        "            output_text, return_tensors=\"pt\", add_special_tokens=self.past_key_values is None\n",
        "        )\n",
        "        self.output_buffer.streamer.put(output_ids)\n",
        "\n",
        "        is_writing = True\n",
        "\n",
        "        step_id = 0\n",
        "\n",
        "        while is_writing:\n",
        "            input_ids = input_ids.to(model.device)\n",
        "            output = self.model(input_ids=input_ids, past_key_values=self.past_key_values)\n",
        "\n",
        "            logits = output.logits\n",
        "            assert len(logits.shape) == 3\n",
        "            assert logits.shape[0] == 1\n",
        "            last_logit = logits[[0], [-1], :]\n",
        "\n",
        "            if step_id <= 2:\n",
        "                last_logit[..., tokenizer.eos_token_id] = -1e4\n",
        "\n",
        "            dist = torch.distributions.Categorical(logits=last_logit / self.t)\n",
        "            next_token = dist.sample()\n",
        "            # Note that parts of cut out text may remain in model memory\n",
        "            # this is implemented in this way for performance reasons\n",
        "            past_key_values = output.past_key_values\n",
        "            assert len(next_token.shape) == 1\n",
        "            should_stop = self.output_buffer.update_buffer(next_token[0].cpu().item())\n",
        "            if should_stop:\n",
        "                is_writing = False\n",
        "            else:\n",
        "                input_ids = next_token[None, :]\n",
        "                self.past_key_values = past_key_values\n",
        "\n",
        "            step_id += 1"
      ]
    },
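    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`ChatOutputBuffer` above ends a reply as soon as the decoded output ends with one of the stop strings (`ASSISTANT:`, `USER:`, or the end-of-sequence token) and drops the tokens that form it. The cell below sketches the same idea on plain strings rather than token ids. It is a simplified illustration, not a replacement for the class, and `truncate_at_stop` is a hypothetical helper introduced here."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from typing import List\n",
        "\n",
        "\n",
        "def truncate_at_stop(text: str, stop_sequences: List[str]) -> str:\n",
        "    \"\"\"Cut text at the earliest occurrence of any stop sequence.\"\"\"\n",
        "    cut = len(text)\n",
        "    for stop in stop_sequences:\n",
        "        position = text.find(stop)\n",
        "        if position != -1:\n",
        "            cut = min(cut, position)\n",
        "    return text[:cut]\n",
        "\n",
        "\n",
        "# The model starts to write the next user turn; everything from USER: on is dropped.\n",
        "reply = \"Sure, here is the code.\\nUSER: thanks\"\n",
        "print(repr(truncate_at_stop(reply, [\"ASSISTANT:\", \"USER:\"])))"
      ]
    },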
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Ask to improve/refactor/rewritte\n",
        "You can ask the model to improve the code.  \n",
        "Note that this is only a 7B parameter model, so the results may contain errors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "try:\n",
        "    del fot_memory\n",
        "except:\n",
        "    pass\n",
        "gc.collect()\n",
        "torch.cuda.empty_cache()\n",
        "\n",
        "chatbot = SimpleChatBot(model=model, tokenizer=tokenizer)\n",
        "chatbot.ask('''\n",
        "#include <bits/stdc++.h>\n",
        "\n",
        "const int MAX_INPUT_LENGTH = 2*1000000;\n",
        "int _preff[MAX_INPUT_LENGTH + 7];    \n",
        "\n",
        "void _some_func(char *_text, int _dlugosc)\n",
        "{\n",
        "    _preff[0] = 0;\n",
        "    _preff[1] = 0;\n",
        "    int p = 0;\n",
        "    for (int i = 2; i <= _dlugosc; ++i)\n",
        "    {\n",
        "        while (p != 0 && _text[p + 1] != _text[i])\n",
        "            p = _preff[p];\n",
        "        \n",
        "        if (_text[p + 1] == _text[i])\n",
        "            ++p;\n",
        "        \n",
        "        _preff[i] = p;\n",
        "    }\n",
        "}\n",
        "\n",
        "Can you change the variable and function names in the code above so that they are more descriptive?''')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "chatbot.ask(\"Great! Can you rewrite it in Python? Change global arrays to local lists.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "When the code is simple the model can even improve it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "chatbot = SimpleChatBot(model=model, tokenizer=tokenizer)\n",
        "chatbot.ask('''\n",
        "def is_prime(x):\n",
        "    if x <= 1:\n",
        "        return False\n",
        "    for i in range(2, x):\n",
        "        if x % i == 0:\n",
        "            return False\n",
        "    return True\n",
        "            \n",
        "I have written the code above but it is pretty slow.\n",
        "Can you make it faster? Say O(sqrt(x)).''')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "chatbot.ask(\"Great! Can you rewrite this faster implementation it in C++?\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "chatbot.ask(\"Thanks! Can you add more comments?\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QPTv8WGAfklZ"
      },
      "source": [
        "## Chat\n",
        "We have also used [ShareGPT-Processed](https://huggingface.co/datasets/zetavg/ShareGPT-Processed) dataset to enhance the model conversation abilities. The chat prompt was inspired by [LongChat](https://github.com/DachengLi1/LongChat)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UYaj4mlIf7gi"
      },
      "source": [
        "Feel free to try the chat yourself:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qjd33BLlfph5"
      },
      "outputs": [],
      "source": [
        "chatbot = SimpleChatBot(model=model, tokenizer=tokenizer)\n",
        "while True:\n",
        "    user_text = input(\"USER: \")\n",
        "    chatbot.ask(user_text)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
